##
from __future__ import annotations
import os
os.chdir('/data/l989o/deployed/a')
import sys
import shutil
import scvi
import scanpy as sc
import torch
from torch.utils.data import DataLoader
from data2 import SumFilteredDataset, file_path
import numpy as np
from tqdm import tqdm
import anndata as ad
import pandas as pd
from pytorch_lightning.loggers import TensorBoardLogger
import os
import matplotlib.pyplot as plt
from utils import memory, reproducible_random_choice
# output: done
# COMPLETE_RUN = True
COMPLETE_RUN = True
N_EPOCHS_KL_WARMUP = 3
N_EPOCHS = 10
m = __name__ == '__main__'
##
if m and False:
# proxy for the DKFZ network
# https://stackoverflow.com/questions/34576665/setting-proxy-to-urllib-request-python3
os.environ["HTTP_PROXY"] = "http://193.174.53.86:80"
os.environ["HTTPS_PROXY"] = "https://193.174.53.86:80"
##
if m and False:
# to have a look at an existing dataset
import scvi.data
data = scvi.data.pbmc_dataset()
data
##
if m:
ds = SumFilteredDataset("train")
@memory.cache
def f_qpoxnqwida(ds):
l0 = []
l1 = []
for i, x in enumerate(tqdm(ds, "merging")):
l0.append(x)
l1.extend([i] * len(x))
return l0, l1
l0, l1 = f_qpoxnqwida(ds)
raw = np.concatenate(l0, axis=0)
raw = np.round(raw)
raw = raw.astype(np.int)
donor = np.array(l1)
a = ad.AnnData(raw)
# output:
# merging: 1%| | 2/226 [00:00<00:20, 10.67it/s]
# ________________________________________________________________________________ [Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-bb8ce796987d>.f_qpoxnqwida... f_qpoxnqwida(<data2.SumFilteredDataset object at 0x7eff6042e1f0>)
# merging: 100%|██████████| 226/226 [00:18<00:00, 12.36it/s]
# ____________________________________________________f_qpoxnqwida - 19.0s, 0.3min
##
if m:
s = pd.Series(donor, index=a.obs.index)
a.obs["batch"] = s
##
if m:
scvi.data.setup_anndata(
a,
# this is probably meaningless (if not even penalizing) for unseen data as the batches are different
# categorical_covariate_keys=["batch"],
)
a
# output: INFO No batch_key inputted, assuming all cells are same batch INFO No label_key inputted, assuming all cells have same label INFO Using data from adata.X INFO Computing library size prior per batch INFO Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. INFO Please do not further modify adata until model is trained.
##
if m:
# TRAIN = True
TRAIN = False
if TRAIN:
# vae = VAE(gene_dataset.nb_genes)
# trainer = UnsupervisedTrainer(
# vae,
# gene_dataset,
# train_size=0.90,
# use_cuda=use_cuda,
# frequency=5,
# )
# []:
# trainer.train(n_epochs=n_epochs, lr=lr)
model = scvi.model.SCVI(a)
##
from data2 import file_path
# the following code, as it is, doesn't work
# logger = TensorBoardLogger(save_dir=file_path("checkpoints"), name="scvi")
# BATCH_SIZE = 128
# indices = np.random.choice(len(a), BATCH_SIZE * 20, replace=False)
#
# train_loader_batch = DataLoader(
# a.X[indices, :],
# batch_size=BATCH_SIZE,
# num_workers=4,
# pin_memory=True,
# )
# model.train(train_size=1., logger=logger, val_dataloaders=train_loader_batch)
# model.__dict__
if m:
if TRAIN:
model.train(train_size=1.0, n_epochs=N_EPOCHS, n_epochs_kl_warmup=N_EPOCHS_KL_WARMUP)
f = file_path("scvi_model.scvi")
if os.path.isdir(f):
shutil.rmtree(f)
model.save(f)
else:
model = scvi.model.SCVI.load(file_path("scvi_model.scvi"), adata=a)
print(model.get_elbo())
# output: INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. -187.95363649203713
##
if m:
z = model.get_latent_representation()
a.shape
z.shape
b = ad.AnnData(z)
random_indices = reproducible_random_choice(len(a), 10000)
aa = a[random_indices]
bb = b[random_indices]
def scanpy_compute(an: ad.AnnData):
sc.tl.pca(an)
print("computing neighbors... ", end="")
sc.pp.neighbors(an)
print("done")
print("computing umap... ", end="")
sc.tl.umap(an)
print("done")
print("computing louvain... ", end="")
sc.tl.louvain(an)
print("done")
##
if m and COMPLETE_RUN:
scanpy_compute(aa)
sc.pl.pca(aa, title="pca, raw data (sum)")
sc.pl.umap(aa, color="louvain", title="umap with louvain, scvi latent (sum)")
# output: computing neighbors... done computing umap... done computing louvain... done
##
if m and COMPLETE_RUN:
scanpy_compute(bb)
sc.pl.pca(bb, title="pca, raw data (sum)")
sc.pl.umap(bb, color="louvain", title="umap with louvain, scvi latent (sum)")
# output: computing neighbors... done computing umap... done computing louvain... done
##
from analyses.ab_images_vs_expression.ab_aa_expression_latent_samples import (
compare_clusters,
nearest_neighbors,
compute_knn,
)
# output: Global seed set to 1234
if m and COMPLETE_RUN:
compare_clusters(aa, bb, description='"raw data (sum)" vs "scvi latent"')
compute_knn(aa)
compute_knn(bb)
nearest_neighbors(
nn_from=aa, plot_onto=bb, title='nn from "raw data (sum)" to "scvi latent"'
)
# output:
# (19, 18)
# done done
# 100%|██████████| 5/5 [00:00<00:00, 347.26it/s]
##
if m:
ds = SumFilteredDataset("validation")
@memory.cache
def f_ncqoi3faoj(ds):
l0 = []
l1 = []
for i, x in enumerate(tqdm(ds, "merging")):
l0.append(x)
l1.extend([i] * len(x))
return l0, l1
l0, l1 = f_ncqoi3faoj(ds)
raw = np.concatenate(l0, axis=0)
donor = np.array(l1)
a_val = ad.AnnData(raw)
# output:
# merging: 1%| | 1/113 [00:00<00:12, 8.79it/s]
# ________________________________________________________________________________ [Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-7f457070833c>.f_ncqoi3faoj... f_ncqoi3faoj(<data2.SumFilteredDataset object at 0x7efe7afe3d30>)
# merging: 100%|██████████| 113/113 [00:08<00:00, 12.59it/s]
# _____________________________________________________f_ncqoi3faoj - 9.3s, 0.2min
##
if m:
# note that here with are embedding without the batch information; if you want to look at batches it does not make
# sense to use another set except to the training one, since the train/val/test split is done by patient first
scvi.data.setup_anndata(
a_val,
)
z_val = model.get_latent_representation(a_val)
b_val = ad.AnnData(z_val)
random_indices_val = reproducible_random_choice(len(a_val), 10000)
aa_val = a_val[random_indices_val]
bb_val = b_val[random_indices_val]
# output:
# INFO No batch_key inputted, assuming all cells are same batch INFO No label_key inputted, assuming all cells have same label INFO Using data from adata.X INFO Computing library size prior per batch INFO Successfully registered anndata object containing 219095 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. INFO Please do not further modify adata until model is trained. WARNING Make sure the registered X field in anndata contains unnormalized count data.
# /data/l989o/miniconda3/envs/spatial_uzh/lib/python3.8/site-packages/scvi/data/_anndata.py:793: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want? warnings.warn(
##
if m and COMPLETE_RUN:
scanpy_compute(aa_val)
scanpy_compute(bb_val)
# output: computing neighbors... done computing umap... done computing louvain... done computing neighbors... done computing umap... done computing louvain... done
##
if m and COMPLETE_RUN:
sc.pl.pca(aa_val, title="pca, raw data (sum); validation set")
sc.pl.umap(
aa_val, color="louvain", title="umap with louvain, scvi latent (sum); valiation set"
)
sc.pl.pca(bb_val, title="pca, raw data (sum); valiation set")
sc.pl.umap(
bb_val, color="louvain", title="umap with louvain, scvi latent (sum); valiation set"
)
##
if m and COMPLETE_RUN:
merged = ad.AnnData.concatenate(bb, bb_val, batch_categories=["train", "validation"])
scanpy_compute(merged)
plt.figure()
ax = plt.gca()
sc.pl.umap(merged, color="batch", ax=ax, show=False)
plt.tight_layout()
plt.show()
# output: computing neighbors... done computing umap... done computing louvain... done
##
if m:
size_factors = model.get_latent_library_size(a_val)
# output: WARNING Make sure the registered X field in anndata contains unnormalized count data.
##
from data2 import AreaFilteredDataset
if m and COMPLETE_RUN:
area_ds = AreaFilteredDataset("validation")
l = []
for x in tqdm(area_ds, desc="merging"):
l.append(x)
areas = np.concatenate(l, axis=0)
# output: merging: 100%|██████████| 113/113 [00:08<00:00, 12.90it/s]
##
if m and COMPLETE_RUN:
from scipy.stats import pearsonr
print(size_factors.shape)
print(areas.shape)
r, p = pearsonr(size_factors.ravel(), areas.ravel())
plt.figure()
plt.scatter(size_factors, areas, s=0.5)
plt.xlabel("latent size factors")
plt.ylabel("cell area")
plt.title(f"r: {r:0.2f} (p: {p:0.2f})")
plt.show()
# output: (219095, 1) (219095, 1)
##
# imputation benchmark
from data2 import PerturbedCellDataset
def get_corrupted_entries(split: str):
ds = PerturbedCellDataset(split)
ds.perturb()
corrupted_entries = ds.corrupted_entries.numpy()
# just a hash
h = np.sum(np.concatenate(np.where(corrupted_entries == 1)))
print(f"corrupted entries hash ({split}):", h)
return corrupted_entries
if m:
ce_train = get_corrupted_entries("train")
ce_val = get_corrupted_entries("validation")
# output: corrupted entries hash (train): 389620020511 corrupted entries hash (validation): 93848327662
##
if m:
ds = SumFilteredDataset("train")
@memory.cache
def f_ncqlliwr2(ds):
l0 = []
for i, x in enumerate(tqdm(ds, "merging")):
l0.append(x)
return l0
l0 = f_ncqlliwr2(ds)
raw = np.concatenate(l0, axis=0)
raw[ce_train] = 0
raw = np.round(raw)
raw = raw.astype(np.int)
a_perturbed = ad.AnnData(raw)
# output:
# merging: 1%| | 2/226 [00:00<00:20, 11.10it/s]
# ________________________________________________________________________________ [Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-2d8cc6047e99>.f_ncqlliwr2... f_ncqlliwr2(<data2.SumFilteredDataset object at 0x7eff54750b20>)
# merging: 100%|██████████| 226/226 [00:17<00:00, 12.95it/s]
# _____________________________________________________f_ncqlliwr2 - 17.6s, 0.3min
##
if m:
scvi.data.setup_anndata(a_perturbed)
# TRAIN_PERTURBED = True
TRAIN_PERTURBED = False
if TRAIN_PERTURBED:
# to navigate there with PyCharm and set a breakpoint on a warning (haven't done yet)
import scvi.core.distributions
model = scvi.model.SCVI(a_perturbed)
if TRAIN_PERTURBED:
model.train(train_size=1.0, n_epochs=N_EPOCHS, n_epochs_kl_warmup=N_EPOCHS_KL_WARMUP)
f = file_path("scvi_model_perturbed.scvi")
if os.path.isdir(f):
shutil.rmtree(f)
model.save(f)
else:
model = scvi.model.SCVI.load(file_path("scvi_model_perturbed.scvi"), adata=a)
print(model.get_elbo())
# output: INFO No batch_key inputted, assuming all cells are same batch INFO No label_key inputted, assuming all cells have same label INFO Using data from adata.X INFO Computing library size prior per batch INFO Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. INFO Please do not further modify adata until model is trained. INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. -195.80324340944367
##
if m:
x_val_perturbed = a_val.X.copy()
x_val_perturbed[ce_val] = 0
a_val_perturbed = ad.AnnData(x_val_perturbed)
##
if m:
p = model.get_likelihood_parameters(a_val_perturbed)
from scvi.core.distributions import ZeroInflatedNegativeBinomial
x_val_perturbed_pred = ZeroInflatedNegativeBinomial(
mu=torch.tensor(p["mean"]),
theta=torch.tensor(p["dispersions"]),
zi_logits=torch.tensor(p["dropout"]),
).mean.numpy()
# output:
# INFO Input adata not setup with scvi. attempting to transfer anndata setup INFO .obs[_scvi_batch] not found in target, assuming every cell is same category INFO .obs[_scvi_labels] not found in target, assuming every cell is same category INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 219095 cells, 39 vars, 1 batches, 1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. WARNING Make sure the registered X field in anndata contains unnormalized count data.
# /data/l989o/miniconda3/envs/spatial_uzh/lib/python3.8/site-packages/scvi/data/_anndata.py:793: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want? warnings.warn(
##
if m:
# ne: normal entries
ne_train = np.logical_not(ce_train)
ne_val = np.logical_not(ce_val)
x_val = a_val.X.copy()
uu0 = x_val_perturbed_pred[ce_val]
uu1 = x_val[ce_val]
vv0 = x_val_perturbed_pred[ne_val]
vv1 = x_val[ne_val]
##
if m:
fig = plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.hist(np.abs(uu0 - uu1))
m = np.mean(np.abs(uu0 - uu1))
plt.title(f"scores for imputed entries\nmean: {m:0.2f}")
plt.yscale("log")
plt.subplot(1, 2, 2)
plt.hist(np.abs(vv0 - vv1))
m = np.mean(np.abs(vv0 - vv1))
plt.title(f"control: normal entries\nmean: {m:0.2f}")
plt.yscale("log")
fig.suptitle("abs(original vs predicted)")
plt.tight_layout()
plt.show()
##
if m:
from analyses.aa_reconstruction_benchmark.aa_ad_reconstruction import Prediction, Space
s = np.abs(uu0 - uu1)
t = np.abs(vv0 - vv1)
Prediction.welch_t_test(s, t)
# the printed p-value is very close to 0
# conclusion: the score for imputed data is worse than the one from non-perturbed data; this is expected and the
# alternative case would have been a model whose scores are both bad because it is not properly trained
##
# output: welch's t test: p_value = 0.0
if m:
scvi_predictions = Prediction(
original=x_val,
corrupted_entries=ce_val,
predictions_from_perturbed=x_val_perturbed_pred,
space=Space.raw_sum,
name='scVI',
split='validation'
)
scvi_predictions.plot_reconstruction()
scvi_predictions.plot_scores()
# output: channels: 100%|██████████| 39/39 [00:34<00:00, 1.12it/s]
##
if m:
p = scvi_predictions.transform_to(Space.scaled_mean)
p.name = 'scVI scaled'
p.plot_reconstruction()
p.plot_scores()
# output:
# applying transformation from raw_sum to raw_mean applying transformation from raw_mean to asinh_mean applying transformation from asinh_mean to scaled_mean applying transformation from raw_sum to raw_mean applying transformation from raw_mean to asinh_mean applying transformation from asinh_mean to scaled_mean
# channels: 100%|██████████| 39/39 [00:36<00:00, 1.07it/s]